{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from advsecurenet.attacks.gradient_based import FGSM, LOTS, PGD\n",
    "from advsecurenet.shared.types.configs.attack_configs import (\n",
    "    FgsmAttackConfig,\n",
    "    LotsAttackConfig,\n",
    "    PgdAttackConfig,\n",
    ")\n",
    "from advsecurenet.shared.types.configs.attack_configs.attacker_config import (\n",
    "    AttackerConfig,\n",
    ")\n",
    "from advsecurenet.models.model_factory import ModelFactory\n",
    "from advsecurenet.datasets.dataset_factory import DatasetFactory\n",
    "from advsecurenet.attacks.attacker import Attacker\n",
    "from advsecurenet.dataloader.data_loader_factory import DataLoaderFactory\n",
    "from advsecurenet.shared.types.configs.preprocess_config import (\n",
    "    PreprocessConfig,\n",
    "    PreprocessStep,\n",
    ")\n",
    "from advsecurenet.shared.types.configs.device_config import DeviceConfig\n",
    "from advsecurenet.utils.adversarial_target_generator import AdversarialTargetGenerator\n",
    "from advsecurenet.datasets.targeted_adv_dataset import AdversarialDataset\n",
    "from advsecurenet.defenses.adversarial_training import AdversarialTraining\n",
    "from advsecurenet.shared.types.configs.defense_configs.adversarial_training_config import (\n",
    "    AdversarialTrainingConfig,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the model\n",
    "model = ModelFactory.create_model(\n",
    "    model_name=\"resnet18\", num_classes=10, pretrained=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's define the preprocessing configuration we want to use\n",
    "preprocess_config = PreprocessConfig(\n",
    "    steps=[\n",
    "        PreprocessStep(name=\"Resize\", params={\"size\": 32}),\n",
    "        PreprocessStep(name=\"CenterCrop\", params={\"size\": 32}),\n",
    "        PreprocessStep(name=\"ToTensor\"),\n",
    "        PreprocessStep(\n",
    "            name=\"ToDtype\", params={\"dtype\": \"torch.float32\", \"scale\": True}\n",
    "        ),\n",
    "        PreprocessStep(\n",
    "            name=\"Normalize\",\n",
    "            params={\"mean\": [0.485, 0.456, 0.406], \"std\": [0.229, 0.224, 0.225]},\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Define the dataset\n",
    "dataset = DatasetFactory.create_dataset(\n",
    "    dataset_type=\"cifar10\", preprocess_config=preprocess_config, return_loaded=False\n",
    ")\n",
    "train_data = dataset.load_dataset(train=True)\n",
    "test_data = dataset.load_dataset(train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the dataloader\n",
    "dataloader = DataLoaderFactory.create_dataloader(dataset=train_data, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the device config\n",
    "device = DeviceConfig(processor=\"mps\")\n",
    "\n",
    "# Define the fgsm config\n",
    "fgsm_config = FgsmAttackConfig(\n",
    "    targeted=False,\n",
    "    epsilon=0.1,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Now we can define the attack\n",
    "fgsm_attack = FGSM(config=fgsm_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pgd_config = PgdAttackConfig(\n",
    "    targeted=False,\n",
    "    epsilon=0.1,\n",
    "    alpha=0.01,\n",
    "    num_iter=10,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "pgd_attack = PGD(config=pgd_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adversarial_training_config = AdversarialTrainingConfig(\n",
    "    model=model,\n",
    "    models=[],  # no ensemble of models\n",
    "    attacks=[fgsm_attack, pgd_attack],  # ensemble of attacks\n",
    "    processor=device.processor,\n",
    "    train_loader=dataloader,\n",
    "    epochs=2,\n",
    ")\n",
    "\n",
    "adversarial_training = AdversarialTraining(config=adversarial_training_config)\n",
    "\n",
    "adversarial_training.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also combine targeted and non-targeted attacks. In this case, we need to either provide the targets ourselves or let `advsecurenet` generate them. Finally, we need to use `AdversarialDataset` as the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create adversarial target generator\n",
    "target_generator = AdversarialTargetGenerator()\n",
    "\n",
    "# We need to generate target images and labels for the targeted attack\n",
    "target_images, target_labels = target_generator.generate_target_images_and_labels(\n",
    "    data=train_data, overwrite=True\n",
    ")\n",
    "\n",
    "# Again, let's create a new dataset with the target images and labels\n",
    "targeted_dataset = AdversarialDataset(\n",
    "    base_dataset=train_data,\n",
    "    target_images=target_images,\n",
    "    target_labels=target_labels,\n",
    ")\n",
    "\n",
    "# And a new dataloader\n",
    "targeted_dataloader = DataLoaderFactory.create_dataloader(\n",
    "    dataset=targeted_dataset, batch_size=32\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "targeted_fgsm_config = FgsmAttackConfig(\n",
    "    targeted=True,\n",
    "    epsilon=0.1,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "targeted_fgsm = FGSM(config=targeted_fgsm_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is also possible to use an ensemble of models. When an ensemble of models is used, one of the models is selected at random for each batch and the attack is performed on that model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vgg16 = ModelFactory.create_model(model_name=\"vgg16\", num_classes=10, pretrained=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here we are using the resnet18 as the main model that we aim to defend and the vgg16 as the model that we use to generate the adversarial examples\n",
    "# Also we are using 3 attacks, the fgsm, pgd and targeted_fgsm\n",
    "targeted_adversarial_training_config = AdversarialTrainingConfig(\n",
    "    model=model,\n",
    "    models=[vgg16],  # ensemble of models: vgg16 is used to generate the adversarial examples\n",
    "    attacks=[fgsm_attack, pgd_attack, targeted_fgsm],  # ensemble of attacks\n",
    "    processor=device.processor,\n",
    "    train_loader=targeted_dataloader,\n",
    "    epochs=1,\n",
    ")\n",
    "\n",
    "adversarial_training = AdversarialTraining(config=targeted_adversarial_training_config)\n",
    "\n",
    "adversarial_training.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "adv2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
